#!/usr/bin/python2
#-*-python-*-

import tempfile
import os
import sys
import numpy
import cf

# Unit tests for the cf package.
#
# This description is a comment rather than a string literal: a string placed
# after the import statements is not picked up as the module docstring, it is
# just an expression whose value is discarded.

print '\n--------------------------------------------------------------------'
print 'TEST: Set chunk size:'

# Remember the chunk size that was in force when the script started, so that
# it can be put back once the tests have finished.
original_chunksize = cf.CHUNKSIZE()

# Run the tests with a chunk size of 60.
# NOTE(review): presumably this is small enough to force the test arrays to
# be split into several partitions - confirm against cf.CHUNKSIZE's units.
cf.CHUNKSIZE(60)
print 'CHUNKSIZE reset to',cf.CHUNKSIZE()

print '\n--------------------------------------------------------------------'
print "TEST: Create a field:"

# Build a field from scratch: start with an empty Field holding an empty
# Space, then attach coordinates, the dimension mapping and the data array.
f = cf.Field()
f.space = cf.Space()
 
f.standard_name = 'eastward_wind'

# dim0: latitude, 10 values 0.0 ... 9.0
f.space['dim0'] = cf.Coordinate()
f.space['dim0'].standard_name = 'latitude'
f.space['dim0'].Data = cf.Data(numpy.arange(10.), 'degrees_north')
 
# dim1: longitude, 9 values 20.0 ... 28.0, with the last value then moved
# to 33.0 so that the coordinate is not evenly spaced.
f.space['dim1'] = cf.Coordinate()
f.space['dim1'].standard_name = 'longitude'
f.space['dim1'].Data = cf.Data(numpy.arange(9.) + 20, 'degrees_east')
f.space['dim1'].Data[-1] += 5
# Longitude bounds: value -/+ 0.5 stacked and transposed into one
# (lower, upper) pair per coordinate value ...
f.space['dim1'].bounds = cf.CoordinateBounds()
bounds = numpy.array([f.space['dim1'].Data.array-0.5,
                      f.space['dim1'].Data.array+0.5]).transpose((1,0))
f.space['dim1'].bounds.Data = cf.Data(bounds, f.space['dim1'].Units)
# ... then the last two cells are overridden: the penultimate cell's upper
# bound becomes 30 and the last cell becomes [30, 36].
f.space['dim1'].bounds.Data[-2,1] = 30
f.space['dim1'].bounds.Data[-1,:] = [30, 36]

# dim2: a single-valued coordinate (1.5, bounds [1, 2]) created without a
# units argument and linked by name to the transform 'trans0'.
f.space['dim2'] = cf.Coordinate()
f.space['dim2'].standard_name = 'atmosphere_hybrid_height_coordinate'
f.space['dim2'].Data = cf.Data(1.5)
f.space['dim2'].bounds = cf.CoordinateBounds()
f.space['dim2'].bounds.Data = cf.Data([1, 2.], f.space['dim2'].Units)
f.space['dim2'].transforms = ['trans0']

# aux0: single value 10 m with bounds [5, 15]
f.space['aux0'] = cf.Coordinate()
f.space['aux0'].id = 'atmosphere_hybrid_height_coordinate_ak'
f.space['aux0'].Data = cf.Data(10., 'm')
f.space['aux0'].bounds = cf.CoordinateBounds()
f.space['aux0'].bounds.Data = cf.Data([5, 15.], f.space['aux0'].Units)

# aux1: single value 20 (no units argument) with bounds [14, 26]
f.space['aux1'] = cf.Coordinate()
f.space['aux1'].id = 'atmosphere_hybrid_height_coordinate_bk'
f.space['aux1'].Data = cf.Data(20.)
f.space['aux1'].bounds = cf.CoordinateBounds()
f.space['aux1'].bounds.Data = cf.Data([14, 26.], f.space['aux1'].Units)

# For the data array and for every space component, the names of the
# dimensions it is listed against. Both auxiliary coordinates are listed
# against 'dim2'.
f.space.dimensions = {'data': ['dim0', 'dim1'],
                      'dim0': ['dim0'],
                      'dim1': ['dim1'],
                      'dim2': ['dim2'],
                      'aux0': ['dim2'],
                      'aux1': ['dim2'],
                      }

f.space.dimension_sizes = {'dim0': 10, 'dim1': 9, 'dim2': 1}

# The field's data: a 10 x 9 array holding 0.0 ... 89.0
f.Data = cf.Data(numpy.arange(90.).reshape(10,9), 'm s-1')

# NOTE(review): presumably finalize() completes/validates the hand-built
# field - its exact effect is not visible from this script.
f.finalize()

# Derive a 'surface_altitude' field from a copy of f: its data are twice
# f's values, in metres. 'dim2' is squeezed out of its space and its axes
# are put in reversed order.
orog = f.copy()
orog.standard_name = 'surface_altitude'
orog.Data = cf.Data(f.array*2, 'm')
orog.squeeze()
orog.space.squeeze('dim2')
orog.transpose([1,0])
orog.finalize()
# The transform refers to the two auxiliary coordinates by their space keys
# and carries the orography field itself.
t = cf.Transform(a='aux0', b='aux1', orog=orog)
t.name = 'atmosphere_hybrid_height_coordinate'

# Store the transform under the key that the 'dim2' coordinate's
# transforms list refers to.
f.space.transforms['trans0'] = t

print cf.dump(f)
print '-----------------'

# Ancillary variables
#
# 'tmp' is a copy of f with the transform and both auxiliary coordinates
# removed; each ancillary field below starts from a fresh copy of it.
tmp = f.copy()
del tmp.coord('atmosphere_hybrid_height_coordinate').transforms
tmp.space.transforms = {}
tmp.space.remove_coordinate('aux0')
tmp.space.remove_coordinate('aux1')

f.ancillary_variables = cf.AncillaryVariables()

# ancillary0: axes in reversed order, values scaled by 0.01
g = tmp.copy()
g.transpose([1,0])
g.standard_name = 'ancillary0'
g *= 0.01
f.ancillary_variables.append(g) 

# ancillary1: 'dim2' squeezed out of the space, values scaled by 0.01
g = tmp.copy()
g.space.squeeze('dim2')
g.standard_name = 'ancillary1'
g *= 0.01
f.ancillary_variables.append(g) 

# ancillary2: only index 0 along the first axis, squeezed, scaled by 0.001
g = tmp.copy()
g = g.subspace[0]
g.squeeze(from_space=True)
g.standard_name = 'ancillary2'
g *= 0.001
f.ancillary_variables.append(g)

# ancillary3: only index 0 along the last axis, squeezed (and 'dim2'
# squeezed out of the space as well), scaled by 0.001
g = tmp.copy()
g = g.subspace[..., 0]
g.squeeze(from_space=True)
g.space.squeeze('dim2')
g.standard_name = 'ancillary3'
g *= 0.001
f.ancillary_variables.append(g)

# NOTE(review): the cell methods set-up below is disabled; it is not clear
# from this script why - re-enable or delete once that is known.
#f.cell_methods = cf.CellMethods('height: mean area: mean')
#f.cell_methods[0].dim  = ['dim2']
#
#f.cell_methods[1].dim  = ['dim0', 'dim1']
#f.cell_methods[1]['name']     = ['latitude'               , 'longitude']
#f.cell_methods[1]['interval'] = [cf.Data(0.1, 'degrees_N'), cf.Data(0.2, 'degrees_E')]

print 'OK'

print '\n--------------------------------------------------------------------'
print 'TEST: Print a dump of the field:'
# Short form first (repr), then the full dump
print repr(f)

print f.dump()

print '\n--------------------------------------------------------------------'
print 'TEST: Print CF properties:'
print f.properties

print '\n--------------------------------------------------------------------'
print "TEST: Shape of the partition array:"
# Report three attributes of the data's partition array as one tuple.
# NOTE(review): presumably these are its number of dimensions, size and
# shape, by analogy with numpy's ndim/size/shape - confirm.
print '(pndim, psize, pshape) =', (f.Data.partitions.pndim,
                                   f.Data.partitions.psize,
                                   f.Data.partitions.pshape)

print '\n--------------------------------------------------------------------'
print "TEST: Non-weighted collapse:"

axes='latitude'
c = cf.collapse(f, 'mean', axes=axes, weights={})
print c
expected = numpy.array([numpy.ma.average(f.array[:,i]) for i in range(9)])
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Unweighted mean over %s didn't work" % repr(axes))

axes='longitude'
c = cf.collapse(f, 'mean', axes=axes, weights={})
print c
expected = numpy.array([numpy.ma.average(f.array[i,:]) for i in range(10)])
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Unweighted mean over %s didn't work" % repr(axes))

axes=['longitude', 'latitude']
c = cf.collapse(f, 'mean', axes=axes, weights={})
print c
expected = numpy.ma.average(f.array)
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Unweighted mean over %s didn't work" % repr(axes))

axes=None
c = cf.collapse(f, 'variance', weights={})
print c
mean = numpy.ma.average(f.array)
expected = numpy.ma.sum((f.array-mean)**2)/(1.0*f.array.size)
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Weighted biased variance over %s didn't work" % repr(axes))

axes=None
c = cf.collapse(f, 'variance', weights={}, unbiased=True)
print c
mean = numpy.ma.average(f.array)
expected = numpy.ma.sum((f.array-mean)**2)/(f.array.size-1)
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Weighted unbiased variance over %s didn't work" % repr(axes))

print '\n--------------------------------------------------------------------'
print "TEST: Weighted collapse:"

axes='latitude'
c = cf.collapse(f, 'mean', axes=axes)
print c
w = cf.tools_collapse.calc_weights(f.coord(axes), infer_bounds=True).array
expected = numpy.array([numpy.ma.average(f.array[:,i], weights=w) for i in range(9)])
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Weighted mean over %s didn't work" % repr(axes))

axes='longitude'
c = cf.collapse(f, 'mean', axes=axes)
print c
w = cf.tools_collapse.calc_weights(f.coord(axes), infer_bounds=True).array
expected = numpy.array([numpy.ma.average(f.array[i,:], weights=w) for i in range(10)])
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Weighted mean over %s didn't work" % repr(axes))

axes=['longitude', 'latitude']
c = cf.collapse(f, 'mean', axes=axes)
print c
w0 = cf.tools_collapse.calc_weights(f.coord('latitude'), infer_bounds=True).array.reshape(10,1)
w1 = cf.tools_collapse.calc_weights(f.coord('longitude'), infer_bounds=True).array.reshape(1,9)
w = w0*w1
expected = numpy.ma.average(f.array, weights=w)
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Weighted mean over %s didn't work" % repr(axes))

axes=None
c = cf.collapse(f, 'variance', weights='default')
print c
mean = numpy.ma.average(f.array, weights=w)
expected = numpy.ma.sum(w*(f.array-mean)**2)/numpy.ma.sum(w)
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Weighted biased variance over %s didn't work" % repr(axes))

axes=None
c = cf.collapse(f, 'variance', weights='default', unbiased=True)
print c
mean = numpy.ma.average(f.array, weights=w)
sow = numpy.ma.sum(w)
sow2 = numpy.ma.sum(w**2)
expected = numpy.ma.sum(w*(f.array-mean)**2)*(sow/(sow**2-sow2))
ok = numpy.ma.allclose(c.array.flatten(), expected)
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Weighted unbiased variance over %s didn't work" % repr(axes))

print '\n--------------------------------------------------------------------'
print "TEST: Cell methods collapse:"

cell_methods='longitude: mean latitude: max'
c = cf.collapse(f, cell_methods, weights='default')
print c
d = cf.collapse(f, 'mean', axes='longitude')
expected = cf.collapse(d, 'max', axes='latitude')
ok = numpy.ma.allclose(c.array.flatten(), expected.array.flatten())
if not ok:    
    #print c.array.flatten(), expected, c.array.flatten()-expected
    raise RuntimeError("Unweighted %s didn't work" % repr(cell_methods))


# Add a cell measure to the field: an 'area' measure holding a 9 x 10 array
# that is listed against the field's dimensions in the order (dim1, dim0),
# i.e. the reverse of the order used for the field's own data.
f.space['cm0'] = cf.CellMeasure()
f.space['cm0'].measure = 'area'
# NOTE(review): the units string is written as 'km 2' - confirm that this
# is the intended spelling of square kilometres.
f.space['cm0'].Data = cf.Data(numpy.arange(90.).reshape(9, 10)*1234, 'km 2')
f.space.dimensions['cm0']= ['dim1', 'dim0']

print '\n--------------------------------------------------------------------'
print 'TEST: Write the field to disk:'
tmpfile = tempfile.mktemp('.nc')
print 'tmpfile=', tmpfile
cf.write(f, tmpfile)

print '\n--------------------------------------------------------------------'
print 'TEST: Read the field from disk:'
g = cf.read(tmpfile)[0]
try:
    del g.history
except AttributeError:
    pass

print g.dump()

print '\n--------------------------------------------------------------------'
print '\nasdasdasdasd'
c = cf.Comparison('set', [0,3,4,5])
print 'f.Data.dtype=',f.Data.dtype
a = (f == c)
print repr(a)
print a.array

print '\n--------------------------------------------------------------------'
print "TEST: Check the equality function:"
# A field must be equal to a copy of itself ...
if not cf.equals(g, g.copy(), traceback=True):
    raise RuntimeError("Field is not equal to itself")

# ... and the in-memory field f must be equal to g, the field that was
# read back from the file f was written to.
if not cf.equals(f, g, traceback=True):
    raise RuntimeError("Field is not equal to itself read back in")

print 'OK'

print '\n--------------------------------------------------------------------'
print "TEST: +, -, *, /, **:"
# Every operation below is immediately followed by its inverse, first with
# the in-place operators and then with the binary operators, so h should
# end up equal to g.
h = g.copy()
h **= 2
h **= 0.5
h *= 10
h /= 10.
h += 100
h -= 100
h = h ** 3
h = h ** (1/3.)
h = h * 1000
h = h / 1000.
h = h + 10000
h = h - 10000
if not cf.equals(g, h, traceback=True):
    raise RuntimeError("+, -, *, / or ** didn't work")

print '\n--------------------------------------------------------------------'
print "TEST: tranpose:"
h.transpose((1, 0))
h.transpose((1, 0))
h.transpose(('longitude', 'latitude'))
h.transpose(('latitude', 'longitude'))
if not cf.equals(g, h, traceback=True):
    raise RuntimeError("Tranpose didn't work")

print '\n--------------------------------------------------------------------'
print "TEST: flip:"
# Flipping the same axes twice should leave h equal to g
h.flip((1, 0))
h.flip((1, 0))
if not cf.equals(g, h, traceback=True):
    raise RuntimeError("Reversing dimensions' directions didn't work")

print '\n--------------------------------------------------------------------'
print "TEST: expand_dims and squeeze:"
# Call expand_dims twice, then undo it: squeeze the data and squeeze the
# two dimensions named first in the data's order out of the space. h
# should end up equal to g.
# NOTE(review): this assumes expand_dims() puts the new dimensions at the
# front of h.Data.order - confirm.
h.expand_dims()
h.expand_dims()
new_dims = h.Data.order[:2]
h.squeeze()
h.space.squeeze(new_dims[0])
h.space.squeeze(new_dims[1])
if not cf.equals(g, h, traceback=True):
    raise RuntimeError("expand_dims or squeeze didn't work")

print '\n--------------------------------------------------------------------'
print "TEST: Access the field's data as a numpy array:"
print g.array

print '\n--------------------------------------------------------------------'
print "TEST: Access the field's coordinates' data arrays:"
# The coordinates are looked up by the abbreviated names 'lat' and 'lon'
print 'latitude :', g.coord('lat').array
print 'longitude:', g.coord('lon').array

print '\n--------------------------------------------------------------------'
print 'TEST: Indices for a subspace defined by coordinates:'
# With no conditions, then with conditions on two and on three coordinates
print f.indices()
print f.indices(latitude = cf.lt(5), longitude = 27)
print f.indices(latitude = cf.lt(5), longitude = 27, atmosphere_hybrid_height_coordinate=1.5)

print '\n--------------------------------------------------------------------'
print 'TEST: Subspace the field:'
# Subspace by coordinate values (call syntax)
print g.subspace(latitude=cf.lt(5), longitude=27, atmosphere_hybrid_height_coordinate=1.5)

print '\n--------------------------------------------------------------------'
print 'TEST: Subspace the field:'
# Subspace by index (square bracket syntax): positions 2, 3 and 4 of the
# last axis
print g.subspace[..., 2:5].array

print '\n--------------------------------------------------------------------'
print 'TEST: Subspace the field:'
# Subspace by index with a negative step along the first axis: positions
# 9, 5 and 1
print g.subspace[9::-4, ...].array

print '\n--------------------------------------------------------------------'
print 'TEST: Create list of fields:'
# The same field object, four times over
fl = cf.FieldList([g, g, g, g])

print '\n--------------------------------------------------------------------'
print 'TEST: Write a list of fields to disk:'
# Both calls write to the same path: first a tuple mixing a field with a
# field list, then the field list on its own. It is the second file that
# is read back below.
cf.write((f, fl), tmpfile)
cf.write(fl, tmpfile)

print '\n--------------------------------------------------------------------'
print 'TEST: Read a list of fields from disk:'
fl = cf.read(tmpfile)
# Remove the 'history' attribute; an AttributeError from delattr is
# deliberately ignored.
try:
    fl.delattr('history')
except AttributeError:
    pass

print repr(fl)

print '\n--------------------------------------------------------------------'
print 'TEST: Print all fields in the list:'
print fl

print '\n--------------------------------------------------------------------'
print 'TEST: Print the last field in the list:'
print fl[-1]

print '\n--------------------------------------------------------------------'
print 'TEST: Print the data of the last field in the list:'
print fl[-1].array

print '\n--------------------------------------------------------------------'
print 'TEST: Modify the last field in the list:'
# In-place multiplication applied through the list index
fl[-1] *= -1
print fl[-1].array

print '\n--------------------------------------------------------------------'
print 'TEST: Changing units\n:'
fl[-1].units = 'mm.s-1'
print fl[-1].array

print '\n--------------------------------------------------------------------'
print 'TEST: Combine fields not in place:'
# Binary subtraction of a field from itself; the result is bound to g,
# replacing the field read from disk earlier.
g = fl[-1] - fl[-1]
print g.array

print '\n--------------------------------------------------------------------'
print 'TEST: Combine field with a size 1 Data object:'
# The operand is a five-dimensional Data object with a single element,
# given in 'cm.s-1'.
g += cf.Data([[[[[1.5]]]]], 'cm.s-1')
print g.array

print '\n--------------------------------------------------------------------'
print "TEST: Setting data array elements to a scalar with subspace[]:"
# The value assigned is always a single element, supplied in a different
# form each time (int, 0-d and 4-d numpy arrays, cf.Data, list), and the
# indices are slices, reversed slices and lists of integers.
g.subspace[...] = 0
print g
g.subspace[3:7, 2:5] = -1
print g.array,'\n'
g.subspace[6:2:-1, 4:1:-1] = numpy.array(-1)
print g.array,'\n'
g.subspace[[0, 3, 8], [1, 7, 8]] = numpy.array([[[[-2]]]])
print g.array,'\n'
g.subspace[[8, 3, 0], [8, 7, 1]] = cf.Data(-3, None)
print g.array,'\n'
g.subspace[[7, 4, 1], slice(6, 8)] = [-4]
print g.array

print '\n--------------------------------------------------------------------'
print "TEST: Setting of (un)masked elements with setitem():"
# After each assignment the whole array is printed, followed by the data of
# the partition at position [0][1] of the partition array.
# NOTE(review): presumably to_memory(1) is what makes that partition's
# data available for printing - confirm.

# Mask every other row at every other column (starting from column 1)
g.subspace[::2, 1::2] = numpy.ma.masked
print g.array,'\n'
g.Data.to_memory(1)
print g.Data.partitions[0][1].data
# setitem with a value only
g.setitem(99)
print g.array,'\n'
g.Data.to_memory(1)
print g.Data.partitions[0][1].data
# setitem with masked=True
g.setitem(2, masked=True)
print g.array,'\n'
g.Data.to_memory(1)
print g.Data.partitions[0][1].data
print '\n--------------'

# The same mask pattern as above, this time applied through setitem with
# explicit indices
g.setitem(numpy.ma.masked, indices=(slice(None, None, 2), slice(1, None, 2)))
print g.array,'\n'
g.Data.to_memory(1)
print g.Data.partitions[0][1].data
# setitem with masked=False and a nested-list value
g.setitem([[-1]], masked=False)
print g.array,'\n'
g.Data.to_memory(1)
print g.Data.partitions[0][1].data
# setitem with a cf.Data value
g.setitem(cf.Data(0, None))
print g.array,'\n'
g.Data.to_memory(1)
print g.Data.partitions[0][1].data

# Take the top-left 3 x 4 corner, set it to -1 with a single 2 at
# position (0, 2), then transpose and flip it ...
h = g.subspace[:3, :4]
h.setitem(-1)
h.setitem(2, (0, 2))
print  h.dump()
h.transpose([1, 0])
print h.array
h.flip([1, 0])
print h.array

# ... and assign that field back into the same corner of g
g.setitem(h, (slice(None, 3), slice(None, 4)))
print g.array
# Repeat with a corner field that has not been transposed or flipped
h = g.subspace[:3, :4]
h.setitem(-1)
h.setitem(2, (0, 2))
g.setitem(h, (slice(None, 3), slice(None, 4)))
print g.array

#print '\n--------------------------------------------------------------------'
#print "TEST: Setting data array elements to array of size > 1:"
#g.subspace[...] = 0
#g.hardmask=True
#g.Data[3:9, 2:5] = numpy.arange(1, 19).reshape(6, 3)
#print g.array,'\n'
#g.subspace[8:2:-1, 4:1:-1] = numpy.arange(1, 19).reshape(6, 3)
#print g.array,'\n'
#g.subspace[8:2:-1, 4:1:-1] = numpy.arange(18, 0, -1).reshape(1, 1, 6, 3)
#print g.array,'\n'
#g.subspace[...] = 0
#g.subspace[[0,1,2,7,8,9],[0,1,8]] = numpy.arange(1, 19).reshape(1, 6, 3)
#print g.array,'\n'
#g.subspace[[9,8,7,2,1,0],[8,1,0]] = numpy.arange(1, 19).reshape(6, 3)
#print g.array,'\n'
#g.subspace[[9,8,7,2,1,0],[8,1,0]] = numpy.arange(1, 4)
#print g.array,'\n'
#g.subspace[[9,8,7,2,1,0],[8,1,0]] = numpy.arange(1, 4).reshape(1, 3)
#print g.array,'\n'
#g.subspace[[9,8,7,2,1,0],[8,1,0]] = numpy.arange(1, 7).reshape(6, 1)
#print g.array,'\n'
#g.subspace[[9,8,7,2,1,0],[8,1,0]] = numpy.arange(1, 7).reshape(1, 6, 1)
#print g.array

print '\n--------------------------------------------------------------------'
print "TEST: Make sure all partitions' data are in temporary files:"
g.Data.to_disk()
print g.Data.partitions.info('data')

print '\n--------------------------------------------------------------------'
print "TEST: Push partitions' data from temporary files into memory:"
g.Data.to_memory(regardless=True)
print g.Data.partitions.info('data')

print '\n--------------------------------------------------------------------'
# NOTE(review): unlike every other section, the partition info is printed
# before the section's header here, so it appears above the "TEST:" line in
# the output - confirm whether that ordering is intended.
print g.Data.partitions.info('data')
print "TEST: Push partitions' data from memory to temporary files:"
g.Data.to_disk()
print g.Data.partitions.info('data')

print '\n--------------------------------------------------------------------'
print "TEST: Iterate through array values:"
for x in f.Data.flat():
    print x,
print

print '\n--------------------------------------------------------------------'
print "TEST: Data any() and all():"
# Print the results of the data's any() and all() methods
print f.Data.any()
print f.Data.all()

print '\n--------------------------------------------------------------------'
print 'TEST: Reset chunk size:'
# Put back the chunk size that was saved at the start of the script
cf.CHUNKSIZE(original_chunksize)
print 'CHUNKSIZE reset to',cf.CHUNKSIZE()

print '\n--------------------------------------------------------------------'
print "TEST: Remove temporary files:"
# Uses a private helper of cf.partition (note the leading underscore)
cf.partition._remove_temporary_files()

print
print '--------------------------------------------------------------------'
print 'Hooray! All tests passed for cf version', cf.__version__
print '--------------------------------------------------------------------'

# Tidy up the temporary file
#
# This line is only reached when every test above has passed: a failing
# test raises and ends the script first, leaving the file behind.
os.remove(tmpfile)
